import numpy as np
import sys
import copy
import matplotlib.pyplot as plt
from evo.core.trajectory import PoseTrajectory3D
from evo.core import sync
from evo.core import lie_algebra as lie
import os

def load_tum_to_pose_traj(filepath):
    """Load a TUM-format trajectory text file into an evo ``PoseTrajectory3D``.

    Each data line is expected to hold eight whitespace-separated numbers:
    ``timestamp tx ty tz qx qy qz qw``. Blank lines and lines starting with
    ``#`` (after leading whitespace) are ignored, and lines with any other
    number of fields are skipped.

    Args:
        filepath: Path to the TUM trajectory file.

    Returns:
        PoseTrajectory3D built from the parsed timestamps, positions and
        orientations. Quaternions are reordered from the file's
        (qx, qy, qz, qw) to (qw, qx, qy, qz) for ``orientations_quat_wxyz``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If an 8-field line contains a non-numeric token.
    """
    timestamps = []
    positions_xyz = []
    orientations_quat_wxyz = []
    with open(filepath, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            # Skip blanks and comments, including indented "# ..." lines.
            if not line or line.startswith("#"):
                continue
            fields = line.split()
            # Check the field count before converting, so malformed lines are
            # skipped instead of raising on a non-numeric token.
            if len(fields) != 8:
                continue
            t, tx, ty, tz, qx, qy, qz, qw = map(float, fields)
            timestamps.append(t)
            positions_xyz.append([tx, ty, tz])
            # File order is x, y, z, w; the evo keyword expects w, x, y, z.
            orientations_quat_wxyz.append([qw, qx, qy, qz])
    return PoseTrajectory3D(positions_xyz=np.array(positions_xyz),
                            orientations_quat_wxyz=np.array(orientations_quat_wxyz),
                            timestamps=np.array(timestamps))

def _save_alignment_plot(gt_positions, est_positions, save_path):
    """Write a top-down (x-y) plot of ground truth vs. aligned estimate to save_path."""
    out_dir = os.path.dirname(save_path)
    # dirname() is "" for a bare filename such as "trajectory_plot.png", and
    # os.makedirs("") raises, so only create a directory when one is named.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(gt_positions[:, 0], gt_positions[:, 1], '--', label='Ground Truth', color='black')
        ax.plot(est_positions[:, 0], est_positions[:, 1], '-', label='Aligned Estimate', color='blue')
        ax.legend()
        ax.set_title("Trajectory Alignment")
        ax.set_xlabel("x [m]")
        ax.set_ylabel("y [m]")
        ax.axis("equal")
        ax.grid(True)
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak figures.
        plt.close(fig)


def compute_recall_from_file(gt_path, est_path, thresh_cm=10, save_path=False, max_diff=0.01):
    """Compute how many ground-truth positions are recalled by an estimated trajectory.

    Both TUM files are loaded and time-associated with
    ``sync.associate_trajectories`` (timestamps within ``max_diff``). The
    associated estimate is then aligned onto the ground truth with rotation,
    translation and scale (``correct_scale=True``). A ground-truth position
    counts as recalled when at least one aligned estimated position lies
    strictly closer than ``thresh_cm`` centimetres. This is a
    nearest-neighbour test over the whole aligned estimate, not a
    same-index comparison.

    Args:
        gt_path: Path to the ground-truth TUM trajectory.
        est_path: Path to the estimated TUM trajectory.
        thresh_cm: Recall distance threshold in centimetres. It is divided by
            100, so positions are assumed to be in metres.
        save_path: If truthy, path where an x-y plot of both trajectories is
            saved. Missing parent directories are created.
        max_diff: Maximum timestamp difference allowed when associating poses.

    Returns:
        Tuple ``(matches, total_poses, recall_percentage)``.
        ``recall_percentage`` is 0.0 when there are no associated poses.
    """
    traj_ref = load_tum_to_pose_traj(gt_path)
    traj_est = load_tum_to_pose_traj(est_path)

    traj_ref_sync, traj_est_sync = sync.associate_trajectories(traj_ref, traj_est, max_diff=max_diff)

    # align() transforms the trajectory in place; the returned rotation,
    # translation and scale are not needed here. Work on a copy so
    # traj_est_sync itself stays unaligned.
    traj_est_aligned = copy.deepcopy(traj_est_sync)
    traj_est_aligned.align(traj_ref_sync, correct_scale=True)

    thresh_m = thresh_cm / 100.0
    gt_positions = traj_ref_sync.positions_xyz
    est_positions = traj_est_aligned.positions_xyz

    matches = 0
    for gt_pos in gt_positions:
        # Row-wise distances keep memory at O(len(est)) rather than building a
        # full pairwise distance matrix.
        distances = np.linalg.norm(est_positions - gt_pos, axis=1)
        if np.min(distances) < thresh_m:
            matches += 1

    total_poses = len(gt_positions)
    recall_percentage = 100.0 * matches / total_poses if total_poses > 0 else 0.0

    if save_path:
        _save_alignment_plot(gt_positions, est_positions, save_path)

    return matches, total_poses, recall_percentage

if __name__ == "__main__":
    # Tokens starting with "--" are treated as flags (only --quiet is
    # recognised); everything else after the script name is positional.
    positional = [token for token in sys.argv[1:] if not token.startswith("--")]
    quiet = "--quiet" in sys.argv

    if len(positional) < 2:
        print("Usage: python compute_recall.py [--quiet] groundtruth.txt estimated.txt [output_plot.png]")
        sys.exit(1)

    # Optional third positional argument overrides the plot destination.
    gt_file, est_file, *extra = positional
    output_plot = extra[0] if extra else "trajectory_plot.png"

    n_hit, n_total, pct = compute_recall_from_file(gt_file, est_file, save_path=output_plot)

    if not quiet:
        print(f"GT poses recalled within 10cm: {n_hit} / {n_total} ({pct:.2f}%)")
        print(f"Plot saved to: {output_plot}")
    else:
        # Machine-readable mode: emit only the percentage.
        print(f"{pct:.2f}")
